import pandas as pd
import numpy as np
import xgboost as xgb
import string
import matplotlib.pyplot as plt
from copy import deepcopy

def load_xgb(MODEL_PATH, FEATURES_PATH, CORES=5):
    """Load a saved XGBoost booster and attach its training feature names.

    Args:
        MODEL_PATH: path handed to ``Booster.load_model``.
        FEATURES_PATH: file read with ``np.load``; every element of the stored
            array becomes an entry of ``booster.feature_names`` (``predict_xgb``
            reads that attribute to align its input columns).
        CORES: value passed to the booster as its ``nthread`` parameter.

    Returns:
        The loaded ``xgb.Booster`` with ``feature_names`` set.
    """
    booster = xgb.Booster({"nthread": CORES})
    booster.load_model(MODEL_PATH)

    # allow_pickle=True lets np.load unpickle object arrays, which can execute
    # arbitrary code -- only point FEATURES_PATH at files you trust.
    stored_names = np.load(FEATURES_PATH, allow_pickle = True)
    booster.feature_names = [name for name in stored_names]
    return booster

def predict_xgb(DATA, MODEL, PREFIX_SEP = "_zl_"):
    """Score ``DATA`` with an XGBoost booster after one-hot encoding it.

    ``DATA`` is expanded with ``pd.get_dummies`` (using ``PREFIX_SEP`` between
    column name and category).

    The result is aligned to ``MODEL.feature_names``:

    * dummy columns the model expects but that are absent here (e.g. a category
      missing from this batch) are added as all-zero;
    * unknown columns are dropped;
    * the remaining columns are put in the model's order.

    Args:
        DATA: input DataFrame.
        MODEL: ``xgb.Booster`` whose ``feature_names`` is set (as ``load_xgb`` does).
        PREFIX_SEP: separator for dummy column names; presumably must match the
            one used when the model was trained -- confirm with the training code.

    Returns:
        Whatever ``MODEL.predict`` returns for the aligned matrix.

    Raises:
        TypeError: if ``MODEL.feature_names`` is None (nothing to align to).
    """
    dt_dummy = pd.get_dummies(DATA, prefix_sep=PREFIX_SEP)
    features = MODEL.feature_names
    if features is None:
        raise TypeError("MODEL.feature_names is None; cannot align the dummy columns")

    # Add every missing column at once (filled with 0) and order them like the
    # model; inserting them one by one fragments the frame.
    dt_dummy = dt_dummy.reindex(columns=features, fill_value=0)

    # get_dummies produces bool columns (pandas >= 2), which would make .values
    # an object array; cast explicitly to the float32 XGBoost stores. NaN stays
    # NaN, which DMatrix treats as missing by default.
    matrix = dt_dummy.to_numpy(dtype=np.float32, na_value=np.nan)

    xgb_data = xgb.DMatrix(data=matrix, feature_names=features)
    return MODEL.predict(xgb_data)

def predict_logistic(DATA, MODEL, PREFIX_SEP = "_zl_"):
    """Score ``DATA`` with a fitted classifier after one-hot encoding it.

    ``DATA`` is expanded with ``pd.get_dummies`` (using ``PREFIX_SEP``) and then
    aligned to ``MODEL.feature_names_in_``:

    * expected columns that are absent are added as all-zero;
    * extra columns are dropped;
    * columns are put in the model's order.

    Args:
        DATA: input DataFrame.
        MODEL: fitted estimator exposing ``feature_names_in_`` and
            ``predict_proba`` (scikit-learn convention).
        PREFIX_SEP: separator for dummy column names; presumably must match the
            one used at training time -- confirm with the training code.

    Returns:
        Column 1 of ``MODEL.predict_proba`` as a 1-D array. For a scikit-learn
        binary classifier this is the probability of the second class.
    """
    dt_dummy = pd.get_dummies(DATA, prefix_sep=PREFIX_SEP)
    features = MODEL.feature_names_in_

    # One reindex adds all missing columns (filled with 0) and fixes the order;
    # inserting them one at a time fragments the frame when many are missing.
    dt_dummy = dt_dummy.reindex(columns=features, fill_value=0)

    return MODEL.predict_proba(dt_dummy)[:, 1]

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))``, elementwise for arrays.

    For very negative ``x``, ``exp(-x)`` overflows to ``inf`` and the result
    correctly saturates to 0.0. That overflow is expected, so numpy's overflow
    RuntimeWarning is silenced here; the returned values are unchanged.
    """
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))


def quantile_cut(dt, var, bin = 5):
    """Bucket ``dt[var]`` into ``bin`` quantile groups, writing labels in place.

    Labels go into column ``f"{var}_CLUSTER"`` and look like ``"b.(21, 41]"``:
    a letter giving the bucket's rank, then the half-open interval it covers.

    How the buckets are formed:

    * Inner cut points are the evenly spaced percentiles of ``dt[var]`` (NaNs
      ignored), rounded to the nearest integer.
    * The first bucket is open to ``-inf`` and the last to ``inf``.
    * Rows whose value is NaN are left unlabelled.
    * Adjacent cut points can coincide after rounding, which leaves a bucket
      empty.

    Args:
        dt: DataFrame, modified in place.
        var: name of the numeric column to bucket.
        bin: number of buckets, 1 to 26 (one letter per bucket); it need not
            divide 100.

    Returns:
        None.

    Raises:
        ValueError: if ``bin`` is outside 1..26.
    """
    letters = string.ascii_lowercase
    n_bins = int(bin)
    if not 1 <= n_bins <= len(letters):
        raise ValueError(f"bin must be between 1 and {len(letters)}, got {bin!r}")

    # Percentile levels strictly inside (0, 100), e.g. bin=5 -> 20, 40, 60, 80.
    # Computed in one call; NaNs would otherwise make int(np.round(nan)) fail.
    levels = np.linspace(0, 100, n_bins + 1)[1:-1]
    inner_edges = np.nanpercentile(dt[var], levels) if levels.size else []
    # Integer-rounded cut points keep the labels short; the top bucket is unbounded.
    upper_edges = [int(np.round(edge)) for edge in inner_edges] + [np.inf]

    column = f"{var}_CLUSTER"
    lower = -np.inf
    for letter, upper in zip(letters, upper_edges):
        in_bucket = (dt[var] > lower) & (dt[var] <= upper)
        dt.loc[in_bucket, column] = f"{letter}.({str(lower)}, {str(upper)}]"
        lower = upper

def benchmark_sample_action(actions_matrix, dict_mapping_actions, seed):
    """Draw one action per row of ``actions_matrix``, using the row as weights.

    How each row is turned into a draw:

    * Negative entries are clipped to 0.
    * Each row is renormalised to sum to 1, so small numerical noise or rows that
      do not sum exactly to one are still accepted.
    * Column ``j`` corresponds to the ``j``-th key of ``dict_mapping_actions`` (in
      dict iteration order).
    * The mapped value for the drawn key is returned.

    The global numpy RNG is reseeded with ``seed``. One ``np.random.choice`` call
    is made per row, in row order.

    Args:
        actions_matrix: 2-D array, one row per decision, one column per action.
        dict_mapping_actions: maps action keys (in column order) to the value
            returned for that action.
        seed: seed for ``np.random.seed``.

    Returns:
        List with one mapped action per row.

    Raises:
        ValueError: if the column count differs from the number of actions, or a
            row has no positive weight after clipping.
    """
    np.random.seed(seed)

    weights = np.maximum(np.asarray(actions_matrix, dtype=float), 0)
    action_keys = list(dict_mapping_actions.keys())

    _, n_cols = weights.shape
    if n_cols != len(action_keys):
        raise ValueError(
            f"actions_matrix has {n_cols} columns but {len(action_keys)} actions are mapped"
        )

    row_sums = weights.sum(axis=1)
    if np.any(row_sums <= 0):
        raise ValueError("every row of actions_matrix needs at least one positive weight")
    probs = weights / row_sums[:, None]

    selected = []
    for row in probs:
        # Draw a column index rather than the key itself, so keys of any
        # hashable type (e.g. tuples) work.
        idx = np.random.choice(len(action_keys), p=row)
        selected.append(dict_mapping_actions[action_keys[idx]])
    return selected

def benchmark_cal_reward_costs(data, model_conversion, budget, seed, apply_null = True):
    """Simulate conversions under each row's ``optim_static_action`` and tally reward/costs.

    ``optim_static_action`` is used as a percentage: each row's ``interest_rate``
    is scaled by ``1 - action/100``. The rescaled rate is used only for scoring
    ``model_conversion`` through ``predict_logistic``.

    Each row then converts when a uniform draw from the global numpy RNG (seeded
    with ``seed``, one draw per row in row order) falls below its predicted
    probability.

    Columns added to the returned frame:
        pred_conv_optim_static_action: predicted conversion probability.
        conversion_optim_static_action: 1 if the row converted, else 0.
        reward: conversion * amount_norm.
        cost2: conversion * discount_base_norm * action/100.
        cost1: conversion * action/100 / 7.

    If ``apply_null`` is truthy, some rows are nulled. A row is nulled when the
    running total (in row order) of ``cost2`` or of ``cost1`` exceeds ``budget``.
    Nulled rows get reward/cost2/cost1 set to 0 and ``optim_static_action``
    set to -1.

    The input frame is not modified. A new frame is returned, so repeated calls
    on the same frame (e.g. one per seed) start from identical inputs.

    Args:
        data: DataFrame with columns interest_rate, optim_static_action,
            amount_norm and discount_base_norm, plus the model's features.
        model_conversion: fitted classifier accepted by ``predict_logistic``.
        budget: threshold applied to both cumulative cost streams.
        seed: seed for ``np.random.seed``.
        apply_null: whether to enforce the budget.

    Returns:
        A copy of ``data`` with the columns above added.
    """
    np.random.seed(seed)
    data = data.copy()

    discount = data["optim_static_action"] / 100

    # Only the scoring frame sees the discounted rate; ``data`` keeps the original.
    scoring_view = data.assign(interest_rate=data["interest_rate"] * (1 - discount))
    data["pred_conv_optim_static_action"] = predict_logistic(scoring_view, model_conversion)

    # One uniform per row in row order: the same RNG stream as calling
    # np.random.rand() once per row, but vectorised.
    uniforms = np.random.rand(len(data))
    converted = uniforms < data["pred_conv_optim_static_action"].to_numpy()
    data["conversion_optim_static_action"] = converted.astype(np.int64)

    conversion = data["conversion_optim_static_action"]
    data["reward"] = conversion * data["amount_norm"]
    data["cost2"] = conversion * data["discount_base_norm"] * discount
    # NOTE(review): the /7 scaling of cost1 is not explained here -- confirm its meaning.
    data["cost1"] = (conversion * data["optim_static_action"]/100)/7

    if apply_null:
        over_budget = (data["cost2"].cumsum() > budget) | (data["cost1"].cumsum() > budget)
        for col in ["reward", "cost2", "cost1"]:
            data.loc[over_budget, col] = 0
        data.loc[over_budget, "optim_static_action"] = -1

    return data

def plot_performance(dict_results, T_range, log_transform = False, keep_legend = True, ax = None, sampling = None):
    """Plot each strategy's curve from ``dict_results`` with a +/- 2*std band.

    Strategies are picked by key name (keys containing "std" are skipped). They
    are drawn in this order and style:

    * keys containing "Box D": brown shades, "^" markers;
    * keys containing "Box C": blue shades, "D" markers;
    * keys containing "Optimal": green, "o" markers.

    Colours are assigned within each family, cycling if a family has more
    entries than shades. Each selected key ``name`` needs a companion
    ``f"{name} - std"`` entry.

    A single marker is drawn 5000 points (scaled by ``sampling``) before the end
    of each curve. For shorter curves it is clamped to the first point.

    The y-limits are fixed from the final value of the last plotted curve:
    [-500, 150] if it is <= 0, otherwise [-150, 500].

    Args:
        dict_results: strategy name -> array aligned with ``T_range``, plus the
            "<name> - std" arrays.
        T_range: x values. With ``sampling`` set, they are used as 1-based
            positions, so presumably 1..N -- confirm with callers.
        log_transform: if true, plot log(value + 1) for the curve and its band.
        keep_legend: if true, draw a legend in the upper-left corner.
        ax: Axes to draw on; when None a new 16x12 figure is created.
        sampling: optional fraction; only every ``int(1/sampling)``-th point is
            plotted and the marker offset is scaled by it.

    Returns:
        ``(fig, ax)`` when the figure was created here, otherwise None.
    """
    marker_position = -5000

    if sampling is not None:
        step = int(1/sampling)
        index_choose = [x-1 for x in T_range if (x % step) == 0]
        marker_position = int(marker_position * sampling)

    pass_ax = ax is not None
    if not pass_ax:
        fig, ax = plt.subplots(figsize=(16, 12))

    # (name tag, colour shades, marker) per family, in plotting order.
    families = [
        ("Box D", ["brown", "saddlebrown", "sandybrown"], "^"),
        ("Box C", ["blue", "steelblue", "deepskyblue"], "D"),
        ("Optimal", ["green"], "o"),
    ]
    styled = []
    for tag, palette, marker in families:
        members = [k for k in dict_results.keys() if (tag in k) and ("std" not in k)]
        for i, name in enumerate(members):
            styled.append((name, palette[i % len(palette)], marker))

    last_y = None
    for name, color, marker in styled:
        # Copies so the plotted data is detached from the caller's arrays.
        y = np.array(dict_results[name])
        ci = 2 * np.array(dict_results[f"{name} - std"])
        upper = y + ci
        lower = y - ci

        if log_transform:
            y = np.log(y + 1)
            lower = np.log(lower + 1)
            upper = np.log(upper + 1)

        x = np.array(T_range)
        if sampling is not None:
            x = x[index_choose]
            y = y[index_choose]
            lower = lower[index_choose]
            upper = upper[index_choose]

        # An out-of-range markevery index makes matplotlib fail at draw time,
        # so clamp it for curves shorter than |marker_position|.
        markevery = [max(marker_position, -len(y))] if len(y) else None
        ax.plot(x, y, marker=marker, color=color, label=name, markevery=markevery)
        ax.fill_between(x, lower, upper, color=color, alpha=.1)
        last_y = y

    if keep_legend and styled:
        ax.legend(loc='upper left', prop={'size': 8})

    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

    if last_y is not None and len(last_y):
        if np.max(last_y[-1]) <= 0:
            ax.set_ylim([-500, 150])
        else:
            ax.set_ylim([-150, 500])

    if not pass_ax:
        return fig, ax
    